import argparse
import functools
import pickle
from pathlib import Path

import jax
import jax.numpy as jnp
import yaml
from flax import serialization
from omegaconf import OmegaConf

from analysis.data_utils.analysis_repertoire import AnalysisLatentRepertoire
from baselines.qdax import environments
from baselines.qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig
from baselines.qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from baselines.qdax.core.neuroevolution.buffers.buffer import QDTransition
from baselines.qdax.environments import get_feat_mean
from baselines.qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function


def _load_pickled_params(path_params, params_template):
  """Restore network parameters from a pickled Flax state dict.

  Args:
    path_params: path of the pickle file holding the state dict.
    params_template: parameter structure handed to
      `serialization.from_state_dict` as the restore target.

  Returns:
    The value returned by `serialization.from_state_dict`.
  """
  # NOTE(security): pickle.load can execute arbitrary code; only point this at
  # result folders coming from a trusted source.
  with open(path_params, "rb") as params_file:
    state_dict = pickle.load(params_file)
  return serialization.from_state_dict(params_template, state_dict)


def get_repertoire_smerl(path_results_to_load, number_visits_per_goal, is_reversed, resolution):
  """Re-evaluate a saved DIAYN+SMERL actor on a grid of goals.

  The Hydra config and the pickled actor / discriminator parameters are read
  from `path_results_to_load`. A grid of centroids is built over the
  environment feature space, the centroids are passed through the
  discriminator to obtain latent goals, and the actor (conditioned on each
  latent goal) is scored `number_visits_per_goal` times.

  Args:
    path_results_to_load: directory containing `.hydra/config.yaml`,
      `actor/actor.pickle` and `discriminator/discriminator.pickle`.
    number_visits_per_goal: number of scoring passes over all goals; must be
      an integer >= 1.
    is_reversed: forwarded to `DiaynSmerlConfig` as `reverse`.
    resolution: grid size along each feature dimension (the grid shape is
      `(resolution,)` repeated once per feature dimension).

  Returns:
    An `AnalysisLatentRepertoire` holding the centroids, the latent goals and
    the fitnesses / descriptors of every pass stacked along axis 1.

  Raises:
    ValueError: if `number_visits_per_goal` is None or smaller than 1.
  """
  # Fail fast: otherwise the problem only surfaces after the (expensive) setup,
  # as a TypeError in `range` or a ValueError when stacking an empty list.
  if number_visits_per_goal is None or number_visits_per_goal < 1:
    raise ValueError(
      f"number_visits_per_goal must be >= 1, got {number_visits_per_goal!r}"
    )

  path_results_to_load = Path(path_results_to_load)
  config = OmegaConf.load(path_results_to_load / ".hydra" / "config.yaml")

  # Init a random key
  random_key = jax.random.PRNGKey(config.seed)

  # Init environment
  env = environments.create(
    config.task + "_" + config.feat,
    episode_length=config.algo.episode_length,
    backend=config.algo.backend,
  )

  # Define config
  smerl_config = DiaynSmerlConfig(
    # SAC config
    batch_size=config.algo.batch_size,
    episode_length=config.algo.episode_length,
    tau=config.algo.soft_tau_update,
    normalize_observations=config.algo.normalize_observations,
    learning_rate=config.algo.learning_rate,
    alpha_init=config.algo.alpha_init,
    discount=config.algo.discount,
    reward_scaling=config.algo.reward_scaling,
    hidden_layer_sizes=config.algo.hidden_layer_sizes,
    fix_alpha=config.algo.fix_alpha,
    # DIAYN config
    skill_type=config.algo.skill_type,
    num_skills=config.algo.num_skills,
    descriptor_full_state=config.algo.descriptor_full_state,
    extrinsic_reward=False,
    beta=1.,
    # SMERL
    reverse=is_reversed,
    diversity_reward_scale=config.algo.diversity_reward_scale,
    smerl_target=config.algo.smerl_target,
    smerl_margin=config.algo.smerl_margin,
  )

  # Define an instance of DIAYN
  smerl = DIAYNSMERL(config=smerl_config, action_size=env.action_size)

  # Freshly initialised parameters are only used as restore targets for the
  # saved state dicts below.
  random_key, random_subkey = jax.random.split(random_key)
  fake_obs = jnp.zeros((env.observation_size + config.algo.num_skills,))
  fake_goal = jnp.zeros((config.algo.num_skills,))
  fake_actor_params = smerl._policy.init(random_subkey, fake_obs)
  fake_discriminator_params = smerl._discriminator.init(random_subkey, fake_goal)

  actor_params = _load_pickled_params(
    path_results_to_load / "actor/actor.pickle",
    fake_actor_params,
  )
  discriminator_params = _load_pickled_params(
    path_results_to_load / "discriminator/discriminator.pickle",
    fake_discriminator_params,
  )

  # Create grid
  feat_space = env.feat_space['vector']
  grid_shape = (resolution,) * feat_space.shape[0]
  goals = compute_euclidean_centroids(grid_shape, minval=feat_space.low, maxval=feat_space.high)
  # Only the first output of the discriminator is used as latent goals.
  latent_goals, _ = smerl._discriminator.apply(discriminator_params, goals)

  reset_fn = jax.jit(env.reset)

  @jax.jit
  def play_step_fn(env_state, params, latent_goal, random_key):
    """Play one environment step with the goal-conditioned actor."""
    # The actor observes the environment observation concatenated with the goal.
    actions, random_key = smerl.select_action(
      obs=jnp.concatenate([env_state.obs, latent_goal], axis=0),
      policy_params=params,
      random_key=random_key,
      deterministic=True,
    )
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
      obs=env_state.obs,
      next_obs=next_state.obs,
      rewards=next_state.reward,
      dones=next_state.done,
      truncations=next_state.info["truncation"],
      actions=actions,
      state_desc=state_desc,
      next_state_desc=next_state.info["state_descriptor"],
      # Descriptor fields are filled with NaN placeholders here.
      desc=jnp.zeros(env.behavior_descriptor_length) * jnp.nan,
      desc_prime=jnp.zeros(env.behavior_descriptor_length) * jnp.nan,
    )

    return next_state, params, latent_goal, random_key, transition

  # Prepare the scoring function
  scoring_fn = jax.jit(functools.partial(
    scoring_actor_dc_function,
    episode_length=config.algo.episode_length,
    play_reset_fn=reset_fn,
    play_step_actor_dc_fn=play_step_fn,
    behavior_descriptor_extractor=get_feat_mean,
  ))

  @jax.jit
  def evaluate_actor(random_key, params, latent_goals):
    """Score the same actor parameters once for every latent goal."""
    # Replicate every parameter leaf along a new leading axis, one copy per goal.
    params = jax.tree_util.tree_map(
      lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), latent_goals.shape[0], axis=0),
      params,
    )
    fitnesses, descriptors, extra_scores, random_key = scoring_fn(
      params, latent_goals, random_key
    )
    return fitnesses, descriptors, extra_scores, random_key

  fitnesses_list = []
  descriptor_list = []
  for index_visit in range(number_visits_per_goal):
    print(f"Smerl: Visit {index_visit} / {number_visits_per_goal}")
    # The random key is threaded through successive passes.
    fitnesses, descriptors, _, random_key = evaluate_actor(random_key, actor_params, latent_goals)
    fitnesses_list.append(fitnesses)
    descriptor_list.append(descriptors)

  smerl_repertoire = AnalysisLatentRepertoire(
    centroids=goals,
    latent_goals=latent_goals,
    fitnesses=jnp.stack(fitnesses_list, axis=1),
    descriptors=jnp.stack(descriptor_list, axis=1),
  )

  return smerl_repertoire


def evaluate_and_save_smerl(path_results_to_load, path_results_to_save, number_visits_per_goal, is_reversed, resolution):
  """Build the SMERL analysis repertoire and pickle it to disk.

  Args:
    path_results_to_load: results directory forwarded to `get_repertoire_smerl`.
    path_results_to_save: directory in which `analysis_repertoire.pkl` is
      written; it is created (with parents) when it does not exist.
    number_visits_per_goal: forwarded to `get_repertoire_smerl`.
    is_reversed: forwarded to `get_repertoire_smerl`.
    resolution: forwarded to `get_repertoire_smerl`.
  """
  path_results_to_save = Path(path_results_to_save)
  # Create the output folder before evaluating, so that a missing directory
  # cannot make the script fail after the whole evaluation has been computed.
  path_results_to_save.mkdir(parents=True, exist_ok=True)

  repertoire = get_repertoire_smerl(
    path_results_to_load,
    number_visits_per_goal,
    is_reversed=is_reversed,
    resolution=resolution,
  )

  path_save_repertoire = path_results_to_save / "analysis_repertoire.pkl"
  with open(path_save_repertoire, "wb") as file:
    pickle.dump(repertoire, file)


def get_args(argv=None):
  """Parse the command-line arguments of the script.

  Args:
    argv: optional list of argument strings. When None, argparse reads the
      arguments from `sys.argv`.

  Returns:
    The `argparse.Namespace` with `path_load`, `path_save`, `num_reevals`,
    `is_reversed` and `resolution`.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('-l', "--path-load", type=str, required=True,
                      help="results directory to load the config and parameters from")
  parser.add_argument('-s', "--path-save", type=str, required=True,
                      help="directory in which the analysis repertoire is saved")
  # Required: without a value the evaluation loop cannot run.
  parser.add_argument('-n', "--num-reevals", type=int, required=True,
                      help="number of evaluation passes per goal")
  parser.add_argument('-r', "--is-reversed", action="store_true",
                      help="forwarded as the `reverse` flag of the SMERL config")
  # No short flag: '-r' is already taken by --is-reversed.
  parser.add_argument("--resolution", type=int, required=True,
                      help="grid size along each feature dimension")
  return parser.parse_args(argv)


def main():
  """Script entry point: parse the arguments, evaluate and save the repertoire."""
  args = get_args()
  # Keyword arguments keep the call in sync with the callee's signature
  # (the `resolution` argument used to be missing here).
  evaluate_and_save_smerl(
    path_results_to_load=args.path_load,
    path_results_to_save=args.path_save,
    number_visits_per_goal=args.num_reevals,
    is_reversed=args.is_reversed,
    resolution=args.resolution,
  )


# Run the evaluation only when this file is executed as a script (not on import).
if __name__ == "__main__":
  main()
